{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import itertools\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import torchvision.datasets as dset\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "from torch.autograd import Variable\n",
    "from torch.utils.data.dataset import Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyper-parameters, kept in one place so they are easy to tune.\n",
    "\n",
    "# Data pipeline: target size handed to the image transform, and the number of\n",
    "# samples per mini-batch handed to the DataLoader.\n",
    "img_size = 64\n",
    "batch_size = 64\n",
    "\n",
    "# Not referenced by the cells visible so far; presumably consumed by a later\n",
    "# training loop / optimizer setup -- TODO confirm once that code exists.\n",
    "n_epochs = 24\n",
    "learning_rate = 2e-4\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Preprocessing: resize every image to img_size, then convert it to a tensor.\n",
    "# transforms.Resize is used instead of transforms.Scale: Scale is a deprecated\n",
    "# alias that newer torchvision releases no longer provide.\n",
    "transform = transforms.Compose([\n",
    "    transforms.Resize(img_size),\n",
    "    transforms.ToTensor(),\n",
    "])\n",
    "\n",
    "# Shuffled mini-batches over the FashionMNIST training split. The dataset files\n",
    "# are downloaded into ~/Datasets/ when they are not already present.\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    dset.FashionMNIST('~/Datasets/', train=True,\n",
    "                      download=True, transform=transform),\n",
    "    batch_size=batch_size,\n",
    "    shuffle=True,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): this cell repeats the previous one with a different dataset root\n",
    "# ('fashion'); running both leaves `transform` and `train_loader` bound to the\n",
    "# objects created here. Consider keeping only one of the two cells.\n",
    "#\n",
    "# The last recorded run of this cell failed with\n",
    "#   EOFError: Compressed file ended before the end-of-stream marker was reached\n",
    "# while extracting fashion/FashionMNIST/raw/train-images-idx3-ubyte.gz, i.e. the\n",
    "# archive on disk was truncated. Delete fashion/FashionMNIST/raw and re-run so\n",
    "# the file is downloaded again.\n",
    "\n",
    "# transforms.Resize replaces the deprecated transforms.Scale alias, which newer\n",
    "# torchvision releases no longer provide.\n",
    "transform = transforms.Compose([\n",
    "    transforms.Resize(img_size),\n",
    "    transforms.ToTensor(),\n",
    "])\n",
    "\n",
    "# Shuffled mini-batches over the FashionMNIST training split stored under ./fashion.\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    dset.FashionMNIST('fashion', train=True,\n",
    "                      download=True, transform=transform),\n",
    "    batch_size=batch_size,\n",
    "    shuffle=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "class discriminator_model(nn.Module):\n",
    "  \"\"\"Convolutional discriminator: one sigmoid-activated score per input image.\n",
    "\n",
    "  Four 4x4, stride-2, padding-1 convolutions (128, 256, 512, 1024 output\n",
    "  channels) each halve the spatial size; a final 4x4 convolution without\n",
    "  padding reduces the result to a single channel. For a 64x64 single-channel\n",
    "  input the output therefore has shape (N, 1, 1, 1).\n",
    "  \"\"\"\n",
    "\n",
    "  def __init__(self):\n",
    "    super().__init__()\n",
    "    # Conv2d positional arguments: in_channels, out_channels, kernel, stride, padding.\n",
    "    self.conv1 = nn.Conv2d(1, 128, 4, 2, 1)\n",
    "    self.conv2 = nn.Conv2d(128, 256, 4, 2, 1)\n",
    "    self.conv2_bn = nn.BatchNorm2d(256)\n",
    "    self.conv3 = nn.Conv2d(256, 512, 4, 2, 1)\n",
    "    self.conv3_bn = nn.BatchNorm2d(512)\n",
    "    self.conv4 = nn.Conv2d(512, 1024, 4, 2, 1)\n",
    "    self.conv4_bn = nn.BatchNorm2d(1024)\n",
    "    self.conv5 = nn.Conv2d(1024, 1, 4, 1, 0)\n",
    "\n",
    "  def weight_init(self):\n",
    "    \"\"\"Call normal_init on every direct sub-module of this network.\n",
    "\n",
    "    NOTE(review): normal_init is not defined in this class nor in any notebook\n",
    "    cell visible so far. It has to exist in the global namespace before this\n",
    "    method is called, otherwise the call raises NameError -- TODO define it.\n",
    "    \"\"\"\n",
    "    for layer in self._modules.values():\n",
    "      normal_init(layer)\n",
    "\n",
    "  def forward(self, input):\n",
    "    \"\"\"Score a batch of images.\n",
    "\n",
    "    Args:\n",
    "      input: image batch with one channel (conv1 expects in_channels == 1),\n",
    "        i.e. shape (N, 1, H, W).\n",
    "\n",
    "    Returns:\n",
    "      Tensor of sigmoid outputs; shape (N, 1, 1, 1) when H == W == 64.\n",
    "    \"\"\"\n",
    "    # The first convolution is followed by LeakyReLU only; the next three are\n",
    "    # each followed by batch normalisation and LeakyReLU (negative slope 0.2).\n",
    "    x = F.leaky_relu(self.conv1(input), 0.2)\n",
    "    x = F.leaky_relu(self.conv2_bn(self.conv2(x)), 0.2)\n",
    "    x = F.leaky_relu(self.conv3_bn(self.conv3(x)), 0.2)\n",
    "    x = F.leaky_relu(self.conv4_bn(self.conv4(x)), 0.2)\n",
    "    # torch.sigmoid instead of the deprecated F.sigmoid.\n",
    "    return torch.sigmoid(self.conv5(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.4 64-bit ('base': conda)",
   "language": "python",
   "name": "python37464bitbaseconda5073072be3c4479bad20c57a46be46a1"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
